package com.diven.spark.ml.learn.feature

import com.diven.spark.ml.learn.core.{BaseSpark, BaseTest}
import org.apache.spark.ml.feature.{ColumnCastType, ColumnFilter}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{DataType, DataTypes}

/** Companion factory: builds the [[ColumnFilterTest]] runner behind the
  * generic [[BaseSpark]] interface expected by the test harness. */
object ColumnFilterTest extends BaseTest {
  /** @return a fresh runner instance, typed as [[BaseSpark]]. */
  def apply(): BaseSpark = {
    val runner = new ColumnFilterTest()
    runner
  }
}

/** Scratch/integration test runner exercising [[ColumnCastType]] persistence:
  * loads a previously saved transformer from disk and applies it to the
  * incoming frame, printing the resulting schema. The commented-out lines
  * below are alternative experiments (fit/save of ColumnCastType and
  * ColumnFilter) kept for manual re-runs.
  */
class ColumnFilterTest extends BaseSpark {

  /** Runs the experiment against the harness-provided frame.
    *
    * @param dataFrame input frame supplied by the [[BaseSpark]] harness;
    *                  assumed to contain "weight", "height" and "age"
    *                  columns — TODO confirm against the harness fixture.
    */
  override def execute(dataFrame: DataFrame): Unit = {
    // Feature column names used by the (currently commented-out) filter/cast
    // experiments below. `val`: never reassigned.
    val features = Array("weight", "height", "age")
//    val a = new ColumnFilter().setFilterColumns(features).transform(dataFrame)

//    new ColumnCastType().setCastColumns(Map("weight"-> DataTypes.StringType, "height" -> DataTypes.IntegerType))
//      .transform(a).printSchema()
//      .save("storage/ColumnCastType/")
//
    // Round-trip check: deserialize the transformer saved by the experiment
    // above and verify it still applies (schema printed for eyeballing).
    ColumnCastType.load("storage/ColumnCastType/").transform(dataFrame).printSchema()

//    new ColumnFilter().setFilterColumns(features).save("storage/ColumnFilter1/")
//    ColumnFilter.load("storage/ColumnFilter1/").transform(dataFrame)

//    dataFrame.show()
  }

}